{
 "cells": [
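  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Neural Architecture Search for Image Classification: Quick Start\n",
    "\n",
    "This tutorial uses the image classification model MobileNetV2 as an example to show how to get started with the [neural architecture search API](../api/nas_api.md) on the cifar10 dataset.\n",
    "The example consists of the following steps:\n",
    "\n",
    "1. Import dependencies\n",
    "2. Initialize a SANAS search instance\n",
    "3. Build the network\n",
    "4. Run the search experiment\n",
    "\n",
    "The following sections describe each step in turn."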
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Import dependencies\n",
    "Make sure Paddle is installed correctly, then import the required packages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle\n",
    "import paddle.fluid as fluid\n",
    "import paddleslim as slim\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Initialize a SANAS search instance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sanas = slim.nas.SANAS(configs=[('MobileNetV2Space')], server_addr=(\"\", 8337))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Build the network\n",
    "Build the training program and the test program from the network architecture passed in."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_program(archs):\n",
    "    train_program = fluid.Program()\n",
    "    startup_program = fluid.Program()\n",
    "    with fluid.program_guard(train_program, startup_program):\n",
    "        data = fluid.data(name='data', shape=[None, 3, 32, 32], dtype='float32')\n",
    "        label = fluid.data(name='label', shape=[None, 1], dtype='int64')\n",
    "        # Apply the searched architecture, then a 10-class classifier for cifar10.\n",
    "        output = archs(data)\n",
    "        output = fluid.layers.fc(input=output, size=10)\n",
    "\n",
    "        softmax_out = fluid.layers.softmax(input=output, use_cudnn=False)\n",
    "        cost = fluid.layers.cross_entropy(input=softmax_out, label=label)\n",
    "        avg_cost = fluid.layers.mean(cost)\n",
    "        acc_top1 = fluid.layers.accuracy(input=softmax_out, label=label, k=1)\n",
    "        acc_top5 = fluid.layers.accuracy(input=softmax_out, label=label, k=5)\n",
    "        # Clone the test program before the optimizer ops are added.\n",
    "        test_program = fluid.default_main_program().clone(for_test=True)\n",
    "\n",
    "        optimizer = fluid.optimizer.Adam(learning_rate=0.1)\n",
    "        optimizer.minimize(avg_cost)\n",
    "\n",
    "        place = fluid.CPUPlace()\n",
    "        exe = fluid.Executor(place)\n",
    "        exe.run(startup_program)\n",
    "    return exe, train_program, test_program, (data, label), avg_cost, acc_top1, acc_top5"
   ]
  },
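  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Run the search experiment\n",
    "Fetch the network architecture for each round and start training. This tutorial uses FLOPs as the constraint. The experiment runs 3 search steps: in each step, an architecture that satisfies the FLOPs constraint is trained for 7 epochs, and an architecture that exceeds the constraint is skipped."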
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for step in range(3):\n",
    "    archs = sanas.next_archs()[0]\n",
    "    exe, train_program, test_program, inputs, avg_cost, acc_top1, acc_top5 = build_program(archs)\n",
    "\n",
    "    # Skip architectures whose FLOPs exceed the constraint.\n",
    "    current_flops = slim.analysis.flops(train_program)\n",
    "    if current_flops > 321208544:\n",
    "        continue\n",
    "\n",
    "    train_reader = paddle.batch(\n",
    "        paddle.reader.shuffle(paddle.dataset.cifar.train10(cycle=False), buf_size=1024),\n",
    "        batch_size=256)\n",
    "    train_feeder = fluid.DataFeeder(inputs, fluid.CPUPlace())\n",
    "    test_reader = paddle.batch(paddle.dataset.cifar.test10(cycle=False), batch_size=256)\n",
    "    test_feeder = fluid.DataFeeder(inputs, fluid.CPUPlace())\n",
    "\n",
    "    outputs = [avg_cost.name, acc_top1.name, acc_top5.name]\n",
    "    for epoch in range(7):\n",
    "        for data in train_reader():\n",
    "            loss, acc1, acc5 = exe.run(train_program, feed=train_feeder.feed(data), fetch_list = outputs)\n",
    "            print(\"TRAIN: loss: {}, acc1: {}, acc5:{}\".format(loss, acc1, acc5))\n",
    "\n",
    "    reward = []\n",
    "    for data in test_reader():\n",
    "        batch_reward = exe.run(test_program, feed=test_feeder.feed(data), fetch_list = outputs)\n",
    "        reward_avg = np.mean(np.array(batch_reward), axis=1)\n",
    "        reward.append(reward_avg)\n",
    "        print(\"TEST: loss: {}, acc1: {}, acc5:{}\".format(batch_reward[0], batch_reward[1], batch_reward[2]))\n",
    "    finally_reward = np.mean(np.array(reward), axis=0)\n",
    "    print(\"FINAL TEST: avg_cost: {}, acc1: {}, acc5: {}\".format(finally_reward[0], finally_reward[1], finally_reward[2]))\n",
    "\n",
    "    sanas.reward(float(finally_reward[1]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}